/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FUSION_ENGINE_OPTIMIZER_COMMON_FORMAT_AXIS_UTIL_H_
#define FUSION_ENGINE_OPTIMIZER_COMMON_FORMAT_AXIS_UTIL_H_

#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "common/fe_inner_error_codes.h"
#include "common/fe_utils.h"
#include "common/math_util.h"
#include "common/util/op_info_util.h"
#include "graph/compute_graph.h"
#include "graph/debug/ge_attr_define.h"

namespace fe {
/* Names for the slots of an axis-value vector. The enumerators are numbered
 * consecutively from zero, in the same order as the {N,C,H,W,C1,C0,...}
 * arrangement documented for axis values further down in this header. */
enum AxisValueType {
  AXIS_N = 0,  // 0
  AXIS_C,      // 1
  AXIS_H,      // 2
  AXIS_W,      // 3
  AXIS_C1,     // 4
  AXIS_C0,     // 5
  AXIS_Co,     // 6
  AXIS_D,      // 7
  AXIS_G,      // 8
  AXIS_BOTTOM  // 9 - last enumerator; presumably an end marker / slot count (TODO confirm)
};
/* Guards the implicit numbering above against accidental insertions. */
static_assert(AXIS_BOTTOM == 9, "AxisValueType numbering must stay 0..9");

using AxisNameNumberMap = std::map<std::string, std::vector<int32_t>>;
const AxisNameNumberMap AXIS_NAME_NUMBER_MAP_5HD{{N_AXIS_NAME, {NC1HWC0_DIM_N}},
                                                 {H_AXIS_NAME, {NC1HWC0_DIM_H}},
                                                 {W_AXIS_NAME, {NC1HWC0_DIM_W}},
                                                 {C_AXIS_NAME, {NC1HWC0_DIM_C1, NC1HWC0_DIM_C0}}};

/* 6HD table, built from the NDC1HWC0_DIM_* constants. Compared with the 5HD
 * table it adds a D entry; C is again listed with both C1 and C0. */
const AxisNameNumberMap AXIS_NAME_NUMBER_MAP_6HD = {
    {N_AXIS_NAME, {NDC1HWC0_DIM_N}},
    {D_AXIS_NAME, {NDC1HWC0_DIM_D}},
    {H_AXIS_NAME, {NDC1HWC0_DIM_H}},
    {W_AXIS_NAME, {NDC1HWC0_DIM_W}},
    {C_AXIS_NAME, {NDC1HWC0_DIM_C1, NDC1HWC0_DIM_C0}},
};

/* Fractal-Z table: only the H and W axes are listed, using the
 * C1HWNCoC0_DIM_* constants. */
const AxisNameNumberMap AXIS_NAME_NUMBER_MAP_FZ = {
    {H_AXIS_NAME, {C1HWNCoC0_DIM_H}},
    {W_AXIS_NAME, {C1HWNCoC0_DIM_W}},
};

/* Fractal-Z 3D table: D, H and W axes, using the C1DHWNCoC0_DIM_* constants. */
const AxisNameNumberMap AXIS_NAME_NUMBER_MAP_FZ_3D = {
    {D_AXIS_NAME, {C1DHWNCoC0_DIM_D}},
    {H_AXIS_NAME, {C1DHWNCoC0_DIM_H}},
    {W_AXIS_NAME, {C1DHWNCoC0_DIM_W}},
};

/* Per-format lookup of the axis-name tables defined above. Each entry stores
 * its own copy of the referenced table.
 * NOTE(review): as namespace-scope const objects defined in a header, this map
 * and the tables it copies are instantiated in every translation unit that
 * includes the header. */
const std::map<ge::Format, AxisNameNumberMap> FORMAT_AXIS_NAME_NUMBER_MAP = {
    {ge::FORMAT_NC1HWC0, AXIS_NAME_NUMBER_MAP_5HD},
    {ge::FORMAT_NDC1HWC0, AXIS_NAME_NUMBER_MAP_6HD},
    {ge::FORMAT_FRACTAL_Z, AXIS_NAME_NUMBER_MAP_FZ},
    {ge::FORMAT_FRACTAL_Z_3D, AXIS_NAME_NUMBER_MAP_FZ_3D},
    /* The 6D format (C1HWNCoC0) shares its axis info with Fractal_Z. */
    {ge::FORMAT_C1HWNCoC0, AXIS_NAME_NUMBER_MAP_FZ},
};
/* Declaration only; the definition is not part of this header.
 * NOTE(review): judging by the name, this presumably returns dividend / divisor
 * rounded up - confirm against the definition, including how a zero divisor is
 * handled. */
int64_t DivisionCeiling(int64_t dividend, int64_t divisor);

/* Callable signature shared by the per-format axis-value getters declared in
 * AxisUtil. Arguments, in order:
 *   1. dimensions of the old (original) shape;
 *   2. c0;
 *   3. axis value, arranged as {N,C,H,W,C1,C0,...} (see AxisValueType);
 *   4. a second vector, named nd_value in the AxisUtil declarations.
 * Arguments 3 and 4 are non-const references, so they are presumably filled in
 * by the callee. The call returns a Status code.
 *
 * std::vector is spelled out in full so that this header does not depend on a
 * using-declaration leaking in from another include. */
using GetAxisValueInfoByFormat = std::function<Status(const std::vector<int64_t>&, const uint32_t&,
                                                      std::vector<int64_t>&, std::vector<int64_t>&)>;

/* Shared-ownership handle to a GetAxisValueInfoByFormat callable. */
using GetAxisValueInfoByFormatPtr = std::shared_ptr<GetAxisValueInfoByFormat>;

/* Collection of static helpers around axis values of a tensor shape. Every
 * member is static; the class holds no per-instance state. Definitions live
 * outside this header, so the notes below describe the declared interface only.
 *
 * std::vector is written fully qualified throughout so the header does not rely
 * on a using-declaration provided by some other include; the parameter types
 * are unchanged. */
class AxisUtil {
 public:
  /* Produces axis values for a shape given in its original format.
   * @param format      original format of dim_vec
   * @param dim_vec     dimensions of the original shape
   * @param c0          c0 value
   * @param axis_value  non-const reference; presumably receives the result,
   *                    arranged as {N,C,H,W,C1,C0,...}
   * @param nd_value    non-const reference; second result vector
   * @return Status code.
   * NOTE(review): presumably dispatches through get_axis_value_func_map - confirm. */
  static Status GetAxisValueByOriginFormat(const ge::Format& format, const std::vector<int64_t>& dim_vec,
                                           const uint32_t& c0, std::vector<int64_t>& axis_value,
                                           std::vector<int64_t>& nd_value);

  /* NOTE(review): presumably reports whether a getter is registered for
   * `format` in get_axis_value_func_map - confirm against the definition. */
  static bool HasAxisValueFunc(const ge::Format& format);

  /* Takes an op description and a shape and reports through axis_index_vec
   * (non-const reference).
   * NOTE(review): `shape` is passed by value, i.e. copied on every call; moving
   * to a const reference would need the definition updated in lockstep. */
  static Status GetOriginAxisAttribute(const ge::OpDesc& op_desc, const ge::GeShape shape,
                                       std::vector<int64_t>& axis_index_vec);

 private:
  /* Shared argument check for the getters below.
   * NOTE(review): presumably validates original_dim_vec against
   * dim_default_size - confirm against the definition. */
  static Status CheckParams(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
                            std::vector<int64_t>& nd_value, const size_t& dim_default_size);

  /* Per-original-format getters. They all share the parameter list of
   * GetAxisValueInfoByFormat: (original dims, c0, axis_value, nd_value). */
  static Status GetAxisValueByNCHW(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                   std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);

  static Status GetAxisValueByNHWC(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                   std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);

  static Status GetAxisValueByNC1HWC0(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                      std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);

  static Status GetAxisValueByFz(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                 std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);

  static Status GetAxisValueByHWCN(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                   std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);

  static Status GetAxisValueByCHWN(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                   std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);

  static Status GetAxisValueByND(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                 std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);

  static Status GetAxisValueByNDHWC(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                    std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);

  static Status GetAxisValueByNCDHW(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                    std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);

  static Status GetAxisValueByDHWCN(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                    std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);

  static Status GetAxisValueByDHWNC(const std::vector<int64_t>& original_dim_vec, const uint32_t& c0,
                                    std::vector<int64_t>& axis_value, std::vector<int64_t>& nd_value);

  /* Table of GetAxisValueInfoByFormat callables keyed by original format; used
   * to get axis values for the different original formats. Defined out of line. */
  static const std::map<ge::Format, GetAxisValueInfoByFormatPtr> get_axis_value_func_map;
};
}  // namespace fe

#endif  // FUSION_ENGINE_OPTIMIZER_COMMON_FORMAT_AXIS_UTIL_H_
